'''
Author: SlytherinGe
LastEditTime: 2021-08-16 20:53:01
'''
import numpy as np
import os
import time
import sys
import xml.etree.ElementTree as ET

def is_small_box_in_big(small_box, big_box):
    '''
    check whether small_box lies completely inside big_box (touching edges count as inside)
    small_box: tuple(xmin, ymin, xmax, ymax)
    big_box: tuple(xmin, ymin, xmax, ymax)
    return: True when no edge of small_box sticks out of big_box, else False
    '''
    sticks_out = (small_box[0] < big_box[0]
                  or small_box[1] < big_box[1]
                  or small_box[2] > big_box[2]
                  or small_box[3] > big_box[3])
    return not sticks_out

def lefttop_rightdown_2_center_width(bbox):
    '''
    convert a corner-format box into center format
    bbox: tuple(xmin, ymin, xmax, ymax)
    return: tuple(xmid, ymid, width, height)
    '''
    xmin, ymin, xmax, ymax = bbox[0], bbox[1], bbox[2], bbox[3]
    x_center = (xmin + xmax) / 2
    y_center = (ymin + ymax) / 2
    return (x_center, y_center, xmax - xmin, ymax - ymin)

def center_width_2_lefttop_rightdown(bbox):
    '''
    convert a center-format box into corner format
    (the inverse of lefttop_rightdown_2_center_width)
    bbox: tuple(xmid, ymid, width, height)
    return: tuple(xmin, ymin, xmax, ymax)
    '''
    xmid, ymid, width, height = bbox[0], bbox[1], bbox[2], bbox[3]
    half_w = width / 2
    # the y extent must come from the height; using the width here only
    # happened to work for square boxes
    half_h = height / 2
    return (xmid - half_w, ymid - half_h, xmid + half_w, ymid + half_h)

def clip_bbox(bbox, im_shape):
    '''
    clamp a box to the image and truncate its coordinates to int
    xmin/ymin below 0 are raised to 0; xmax/ymax at or beyond the image
    width/height are lowered to the last pixel index (size - 1)
    param: bbox: tuple(xmin, ymin, xmax, ymax)
    param: im_shape: tuple(im_width, im_height)
    return: tuple of int (xmin, ymin, xmax, ymax)
    '''
    xmin, ymin, xmax, ymax = bbox
    im_w, im_h = im_shape[0], im_shape[1]
    xmin = 0 if xmin < 0 else xmin
    ymin = 0 if ymin < 0 else ymin
    xmax = im_w - 1 if xmax >= im_w else xmax
    ymax = im_h - 1 if ymax >= im_h else ymax
    return tuple(int(v) for v in (xmin, ymin, xmax, ymax))

def voc_basic_xml_element_generator(filename, folder, size_w, size_h, size_c):
    '''
    description: build the common top-level elements of a VOC annotation
    param:
        filename: text for the <filename> element
        folder: text for the <folder> element
        size_w, size_h, size_c: image width, height and depth; each is
            converted with str() and stored under <size>
    return:
        tuple of detached elements (<filename>, <folder>, <size>)
    '''
    filename_ele = ET.Element('filename')
    filename_ele.text = filename
    folder_ele = ET.Element('folder')
    folder_ele.text = folder

    size_ele = ET.Element('size')
    for tag, value in (('width', size_w), ('height', size_h), ('depth', size_c)):
        ET.SubElement(size_ele, tag).text = str(value)

    return filename_ele, folder_ele, size_ele

def voc_object_xml_element_generator(object_dict):
    '''
    description: convert the dict of a VOC object into a XML element
    param:
        object_dict: mapping of child tag -> value, emitted in dict order.
            The 'bndbox' value must itself be a mapping of coordinate tag ->
            value; every other value becomes the text of its child element.
    return:
        obj: an <object> element with one child per key
    '''
    obj = ET.Element('object')
    for key, value in object_dict.items():
        child = ET.SubElement(obj, key)
        if key == 'bndbox':
            for coord_key, coord_value in value.items():
                ET.SubElement(child, coord_key).text = str(coord_value)
        else:
            # stringify like the bndbox branch does: ElementTree cannot
            # serialize non-str text (e.g. an int 0); None still means "no text"
            child.text = value if value is None else str(value)
    return obj

class AOI_Generator(object):
    '''
    generate Area Of Interest (AOI) annotations from a folder of VOC XML files.

    All annotations are parsed into memory when the object is built.
    generate() groups the boxes of each image into square AOIs, and
    save_all() writes one VOC XML file per input file (same file name), with
    every AOI stored as an object named 'aoi'.
    '''
    def __init__(self, anno_root, out_root):
        '''
        param:
            anno_root: directory of VOC XML annotation files; every entry in
                the directory is parsed, so it should contain nothing else
            out_root: not used by this class (save_all receives its destination
                folder explicitly); kept for interface compatibility
        '''
        self.anno_root = anno_root
        # sorted so that an index always maps to the same file
        self.anno_files = sorted(os.listdir(self.anno_root))
        self.dataset_size = len(self.anno_files)
        self.anno = []
        # filled by generate(): one list of AOI dicts per annotation file
        self.aois = []
        print('loading annotations into memory...')
        tick = time.time()
        for i in range(self.dataset_size):
            self.anno.append(self.__load_voc_annotation(i))
        tock = time.time()
        print('annotations loaded! used time :{:.2f}s'.format(tock - tick))

    def __load_voc_annotation(self, index):
        '''
        parse the index-th annotation file into a dict.
        return: dict with keys
            'filename', 'folder',
            'size': {'width', 'height', 'depth'},
            'object': list of {'name', 'truncated', 'difficult',
                               'bndbox': {'xmin', 'ymin', 'xmax', 'ymax'}}
            every leaf value is the raw XML text of the matching tag
        '''
        assert index < self.dataset_size
        filename = os.path.join(self.anno_root, self.anno_files[index])
        tree = ET.parse(filename)
        size = tree.find('size')
        anno_dict = {
            'filename': tree.find('filename').text,
            'folder': tree.find('folder').text,
            'size': {
                'width': size.find('width').text,
                'height': size.find('height').text,
                'depth': size.find('depth').text,
            },
        }

        obj_list = []
        for obj in tree.findall('object'):
            bndbox = obj.find('bndbox')
            obj_list.append({
                'name': obj.find('name').text,
                'truncated': obj.find('truncated').text,
                'difficult': obj.find('difficult').text,
                'bndbox': {
                    'xmin': bndbox.find('xmin').text,
                    'ymin': bndbox.find('ymin').text,
                    'xmax': bndbox.find('xmax').text,
                    'ymax': bndbox.find('ymax').text,
                },
            })

        anno_dict['object'] = obj_list
        return anno_dict

    def __pygenerate(self, anno_index, aoi_size):
        '''
        group the boxes of one image into square AOIs of side aoi_size.
        Boxes wider or taller than aoi_size are discarded.
        return: list of dict
            [{'aoi': (xmin, ymin, xmax, ymax), 'bbox': [(xmin, ymin, xmax, ymax), ...]}, ...]
            where 'aoi' is re-centred on its boxes and clipped to the image (ints)
        '''
        anno = self.anno[anno_index]
        im_h = int(anno['size']['height'])
        im_w = int(anno['size']['width'])
        half = aoi_size / 2

        # keep only boxes small enough to fit inside a single AOI
        bboxes = []
        for obj in anno['object']:
            coords = obj['bndbox']
            # int(float(...)) also accepts coordinates written as e.g. "12.0"
            bbox = tuple(int(float(coords[k])) for k in ('xmin', 'ymin', 'xmax', 'ymax'))
            if (bbox[2] - bbox[0]) > aoi_size or (bbox[3] - bbox[1]) > aoi_size:
                continue
            bboxes.append(bbox)

        # greedily seed an AOI on the first remaining box and absorb every
        # other remaining box that lies fully inside it
        aois = []
        while bboxes:
            seed = bboxes[0]
            cx, cy = lefttop_rightdown_2_center_width(seed)[:2]
            aoi = (cx - half, cy - half, cx + half, cy + half)
            members = [seed]
            rest = []
            # partition into a new list: removing from `bboxes` while iterating
            # it skipped the element after each removal
            for bbox in bboxes[1:]:
                if is_small_box_in_big(bbox, aoi):
                    members.append(bbox)
                else:
                    rest.append(bbox)
            aois.append({'aoi': aoi, 'bbox': members})
            bboxes = rest

        # refine aois: centre each one on the mean centre of its boxes,
        # then clip it to the image
        for aoi in aois:
            centers = [lefttop_rightdown_2_center_width(b)[:2] for b in aoi['bbox']]
            cx = sum(c[0] for c in centers) / len(centers)
            cy = sum(c[1] for c in centers) / len(centers)
            aoi['aoi'] = clip_bbox((cx - half, cy - half, cx + half, cy + half), (im_w, im_h))

        return aois

    def __save_aois(self, index, save_path, save_origin_bbox=False):
        '''
        write the AOIs of the index-th image to save_path as a VOC XML file.
        param:
            index: annotation index; generate() must have been called first
            save_path: output file path
            save_origin_bbox: if True, the original objects are written too
        '''
        assert index < self.dataset_size
        anno = self.anno[index]
        size = anno['size']
        root = ET.Element('annotation')
        basic_anno = voc_basic_xml_element_generator(anno['filename'], anno['folder'], size['width'],
                                                     size['height'], size['depth'])
        for basic in basic_anno:
            root.append(basic)

        if save_origin_bbox:
            for dict_obj in anno['object']:
                root.append(voc_object_xml_element_generator(dict_obj))

        for aoi in self.aois[index]:
            xmin, ymin, xmax, ymax = aoi['aoi']
            dict_aoi_obj = {
                'name': 'aoi',
                'truncated': '0',
                'difficult': '0',
                'bndbox': {
                    'xmin': xmin,
                    'ymin': ymin,
                    'xmax': xmax,
                    'ymax': ymax,
                },
            }
            root.append(voc_object_xml_element_generator(dict_aoi_obj))

        ET.ElementTree(root).write(save_path, encoding='utf-8', xml_declaration=True)

    def generate(self, aoi_size):
        '''
        compute the AOIs of every image, replacing any previously generated ones
        (appending made a second call leave the stale first results at the
        indices used by save_all / get_aois_from_index).
        param: aoi_size: side length of the square AOIs, in box-coordinate units
        return: list with one AOI list per annotation file
        '''
        self.aois = [self.__pygenerate(i, aoi_size) for i in range(self.dataset_size)]
        return self.aois

    def save_all(self, save_folder, save_origin_bbox=False):
        '''
        write the AOIs of every image into save_folder, one XML file per
        annotation file (same file name). The folder, including missing
        parents, is created if needed.
        '''
        print('start saving aoi annotations....')
        os.makedirs(save_folder, exist_ok=True)
        tick = time.time()
        for i in range(self.dataset_size):
            self.__save_aois(i, os.path.join(save_folder, self.anno_files[i]), save_origin_bbox)
        tock = time.time()
        print('done saving! time used: {:.2f}s'.format(tock - tick))

    def clear_aois(self):
        '''drop all generated AOIs.'''
        self.aois = []

    def data_size(self):
        '''number of annotation files.'''
        return len(self.anno_files)

    def get_anno_from_index(self, index):
        '''parsed annotation dict of the index-th file.'''
        return self.anno[index]

    def get_aois_from_index(self, index):
        '''generated AOI list of the index-th file.'''
        return self.aois[index]

    
if __name__ == '__main__':
    import argparse

    # defaults are the original machine-specific paths; pass options on the
    # command line to run anywhere else
    ANNO_ROOT = '/media/gejunyao/Disk1/Datasets/SARship-Total/PASCAL_VOC/VOC2012/Annotations/'
    OUT_ROOT = '/media/gejunyao/Disk/Gejunyao/develop/temp/big_anno'

    parser = argparse.ArgumentParser(description='generate AOI annotations from VOC annotations')
    parser.add_argument('--anno-root', default=ANNO_ROOT,
                        help='directory of VOC XML annotation files')
    parser.add_argument('--out-root', default=OUT_ROOT,
                        help='directory to write the AOI annotation files into')
    parser.add_argument('--aoi-size', type=int, default=100,
                        help='side length of each square AOI')
    parser.add_argument('--save-origin-bbox', action='store_true',
                        help='also copy the original objects into the output files')
    args = parser.parse_args()

    ag = AOI_Generator(args.anno_root, args.out_root)
    ag.generate(args.aoi_size)
    ag.save_all(args.out_root, args.save_origin_bbox)